{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Datawhale 气象海洋预测-Task5 模型建立之 SA-ConvLSTM\n",
    "\n",
    "本次任务我们将学习来自TOP选手“吴先生的队伍”的建模方案，该方案中采用的模型是SA-ConvLSTM。\n",
    "\n",
    "前两个TOP方案中选择将赛题看作一个多输出的任务，通过构建神经网络直接输出24个nino3.4预测值，这种思路的问题在于，序列问题往往是时序依赖的，当我们采用多输出的方法时其实把这24个nino3.4预测值看作是完全独立的，但是实际上它们之间是存在序列依赖的，即每个预测值往往受上一个时间步的预测值的影响。因此，在这次的TOP方案中，采用Seq2Seq结构来考虑输出预测值的序列依赖性。\n",
    "\n",
    "Seq2Seq结构包括Encoder（编码器）和Decoder（解码器）两部分，Encoder部分将输入序列编码成一个向量，Decoder部分对向量进行解码，输出一个预测序列。要将Seq2Seq结构应用于不同的序列问题，关键在于每一个时间步所使用的Cell。我们之前说到，挖掘空间信息通常会采用CNN，挖掘时间信息通常会采用RNN或LSTM，将二者结合在一起就得到了时空序列领域的经典模型——ConvLSTM，我们本次要学习的SA-ConvLSTM模型是对ConvLSTM模型的改进，在其基础上引入了自注意力机制来提高模型对于长期空间依赖关系的挖掘能力。\n",
    "\n",
    "另外与前两个TOP方案所不同的一点是，该TOP方案没有直接预测Nino3.4指数，而是通过预测sst来间接求得Nino3.4指数序列。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 学习目标\n",
    "1. 学习TOP方案的模型构建方法"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 内容介绍\n",
    "1. 数据处理\n",
    "    - 数据扁平化\n",
    "    - 空值填充\n",
    "    - 构造数据集\n",
    "2. 模型构建\n",
    "    - 构造评估函数\n",
    "    - 模型构造\n",
    "    - 模型训练\n",
    "    - 模型评估\n",
    "3. 总结"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 代码示例"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据处理\n",
    "该TOP方案的数据处理主要包括三部分：\n",
    "1. 数据扁平化。\n",
    "2. 空值填充。\n",
    "3. 构造数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:04:34.698663Z",
     "iopub.status.busy": "2021-11-29T03:04:34.697133Z",
     "iopub.status.idle": "2021-11-29T03:04:37.035400Z",
     "shell.execute_reply": "2021-11-29T03:04:37.034767Z",
     "shell.execute_reply.started": "2021-11-29T01:02:51.883602Z"
    },
    "papermill": {
     "duration": 2.370278,
     "end_time": "2021-11-29T03:04:37.035673",
     "exception": false,
     "start_time": "2021-11-29T03:04:34.665395",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import netCDF4 as nc\n",
    "import random\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "\n",
    "from sklearn.metrics import mean_squared_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:04:37.102995Z",
     "iopub.status.busy": "2021-11-29T03:04:37.102144Z",
     "iopub.status.idle": "2021-11-29T03:04:37.107646Z",
     "shell.execute_reply": "2021-11-29T03:04:37.107161Z",
     "shell.execute_reply.started": "2021-11-29T01:02:54.06493Z"
    },
    "papermill": {
     "duration": 0.040737,
     "end_time": "2021-11-29T03:04:37.107761",
     "exception": false,
     "start_time": "2021-11-29T03:04:37.067024",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 固定随机种子\n",
    "SEED = 22\n",
    "\n",
    "def seed_everything(seed=42):\n",
    "    random.seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    \n",
    "seed_everything(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:04:37.222525Z",
     "iopub.status.busy": "2021-11-29T03:04:37.221100Z",
     "iopub.status.idle": "2021-11-29T03:04:37.225844Z",
     "shell.execute_reply": "2021-11-29T03:04:37.226442Z",
     "shell.execute_reply.started": "2021-11-29T01:02:54.074875Z"
    },
    "papermill": {
     "duration": 0.090198,
     "end_time": "2021-11-29T03:04:37.226602",
     "exception": false,
     "start_time": "2021-11-29T03:04:37.136404",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CUDA is available!  Training on GPU ...\n"
     ]
    }
   ],
   "source": [
    "# 查看CUDA是否可用\n",
    "train_on_gpu = torch.cuda.is_available()\n",
    "\n",
    "if not train_on_gpu:\n",
    "    print('CUDA is not available.  Training on CPU ...')\n",
    "else:\n",
    "    print('CUDA is available!  Training on GPU ...')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:04:37.353082Z",
     "iopub.status.busy": "2021-11-29T03:04:37.352143Z",
     "iopub.status.idle": "2021-11-29T03:04:37.432852Z",
     "shell.execute_reply": "2021-11-29T03:04:37.434332Z",
     "shell.execute_reply.started": "2021-11-28T10:13:13.644947Z"
    },
    "papermill": {
     "duration": 0.179146,
     "end_time": "2021-11-29T03:04:37.434792",
     "exception": false,
     "start_time": "2021-11-29T03:04:37.255646",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 读取数据\n",
    "\n",
    "# 存放数据的路径\n",
    "path = '/kaggle/input/ninoprediction/'\n",
    "soda_train = nc.Dataset(path + 'SODA_train.nc')\n",
    "soda_label = nc.Dataset(path + 'SODA_label.nc')\n",
    "cmip_train = nc.Dataset(path + 'CMIP_train.nc')\n",
    "cmip_label = nc.Dataset(path + 'CMIP_label.nc')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 数据扁平化\n",
    "采用滑窗构造数据集。该方案中只使用了sst特征，且只使用了lon值在[90, 330]范围内的数据，可能是为了节约计算资源。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:04:37.548239Z",
     "iopub.status.busy": "2021-11-29T03:04:37.546737Z",
     "iopub.status.idle": "2021-11-29T03:04:37.551951Z",
     "shell.execute_reply": "2021-11-29T03:04:37.553081Z",
     "shell.execute_reply.started": "2021-11-27T13:38:32.620904Z"
    },
    "papermill": {
     "duration": 0.065069,
     "end_time": "2021-11-29T03:04:37.553274",
     "exception": false,
     "start_time": "2021-11-29T03:04:37.488205",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def make_flatted(train_ds, label_ds, info, start_idx=0):\n",
    "    # 只使用sst特征\n",
    "    keys = ['sst']\n",
    "    label_key = 'nino'\n",
    "    # 年数\n",
    "    years = info[1]\n",
    "    # 模式数\n",
    "    models = info[2]\n",
    "    \n",
    "    train_list = []\n",
    "    label_list = []\n",
    "    \n",
    "    # 将同种模式下的数据拼接起来\n",
    "    for model_i in range(models):\n",
    "        blocks = []\n",
    "        \n",
    "        # 对每个特征，取每条数据的前12个月进行拼接，只使用lon值在[90, 330]范围内的数据\n",
    "        for key in keys:\n",
    "            block = train_ds[key][start_idx + model_i * years: start_idx + (model_i + 1) * years, :12, :, 19: 67].reshape(-1, 24, 48, 1).data\n",
    "            blocks.append(block)\n",
    "        \n",
    "        # 将所有特征在最后一个维度上拼接起来\n",
    "        train_flatted = np.concatenate(blocks, axis=-1)\n",
    "        \n",
    "        # 取12-23月的标签进行拼接，注意加上最后一年的最后12个月的标签（与最后一年12-23月的标签共同构成最后一年前12个月的预测目标）\n",
    "        label_flatted = np.concatenate([\n",
    "            label_ds[label_key][start_idx + model_i * years: start_idx + (model_i + 1) * years, 12: 24].reshape(-1).data,\n",
    "            label_ds[label_key][start_idx + (model_i + 1) * years - 1, 24: 36].reshape(-1).data\n",
    "        ], axis=0)\n",
    "        \n",
    "        train_list.append(train_flatted)\n",
    "        label_list.append(label_flatted)\n",
    "        \n",
    "    return train_list, label_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:04:37.661954Z",
     "iopub.status.busy": "2021-11-29T03:04:37.660977Z",
     "iopub.status.idle": "2021-11-29T03:05:11.515409Z",
     "shell.execute_reply": "2021-11-29T03:05:11.515853Z",
     "shell.execute_reply.started": "2021-11-27T13:38:33.844185Z"
    },
    "papermill": {
     "duration": 33.912013,
     "end_time": "2021-11-29T03:05:11.516001",
     "exception": false,
     "start_time": "2021-11-29T03:04:37.603988",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1, 1200, 24, 48, 1), (15, 1812, 24, 48, 1), (17, 1680, 24, 48, 1))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "soda_info = ('soda', 100, 1)\n",
    "cmip6_info = ('cmip6', 151, 15)\n",
    "cmip5_info = ('cmip5', 140, 17)\n",
    "\n",
    "soda_trains, soda_labels = make_flatted(soda_train, soda_label, soda_info)\n",
    "cmip6_trains, cmip6_labels = make_flatted(cmip_train, cmip_label, cmip6_info)\n",
    "cmip5_trains, cmip5_labels = make_flatted(cmip_train, cmip_label, cmip5_info, cmip6_info[1]*cmip6_info[2])\n",
    "\n",
    "# 得到扁平化后的数据维度为（模式数×序列长度×纬度×经度×特征数），其中序列长度=年数×12\n",
    "np.shape(soda_trains), np.shape(cmip6_trains), np.shape(cmip5_trains)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 空值填充\n",
    "将空值填充为0。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:11.638562Z",
     "iopub.status.busy": "2021-11-29T03:05:11.637553Z",
     "iopub.status.idle": "2021-11-29T03:05:11.644302Z",
     "shell.execute_reply": "2021-11-29T03:05:11.644742Z",
     "shell.execute_reply.started": "2021-11-27T13:39:22.665855Z"
    },
    "papermill": {
     "duration": 0.040786,
     "end_time": "2021-11-29T03:05:11.644893",
     "exception": false,
     "start_time": "2021-11-29T03:05:11.604107",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of null in soda_trains after fillna: 0\n"
     ]
    }
   ],
   "source": [
    "# 填充SODA数据中的空值\n",
    "soda_trains = np.array(soda_trains)\n",
    "soda_trains_nan = np.isnan(soda_trains)\n",
    "soda_trains[soda_trains_nan] = 0\n",
    "print('Number of null in soda_trains after fillna:', np.sum(np.isnan(soda_trains)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:11.709054Z",
     "iopub.status.busy": "2021-11-29T03:05:11.707767Z",
     "iopub.status.idle": "2021-11-29T03:05:11.862744Z",
     "shell.execute_reply": "2021-11-29T03:05:11.863294Z",
     "shell.execute_reply.started": "2021-11-27T13:39:24.110039Z"
    },
    "papermill": {
     "duration": 0.18937,
     "end_time": "2021-11-29T03:05:11.863480",
     "exception": false,
     "start_time": "2021-11-29T03:05:11.674110",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of null in cmip6_trains after fillna: 0\n"
     ]
    }
   ],
   "source": [
    "# 填充CMIP6数据中的空值\n",
    "cmip6_trains = np.array(cmip6_trains)\n",
    "cmip6_trains_nan = np.isnan(cmip6_trains)\n",
    "cmip6_trains[cmip6_trains_nan] = 0\n",
    "print('Number of null in cmip6_trains after fillna:', np.sum(np.isnan(cmip6_trains)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:11.927752Z",
     "iopub.status.busy": "2021-11-29T03:05:11.925353Z",
     "iopub.status.idle": "2021-11-29T03:05:12.091117Z",
     "shell.execute_reply": "2021-11-29T03:05:12.091855Z",
     "shell.execute_reply.started": "2021-11-27T13:39:24.520724Z"
    },
    "papermill": {
     "duration": 0.197975,
     "end_time": "2021-11-29T03:05:12.092014",
     "exception": false,
     "start_time": "2021-11-29T03:05:11.894039",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of null in cmip6_trains after fillna: 0\n"
     ]
    }
   ],
   "source": [
    "# 填充CMIP5数据中的空值\n",
    "cmip5_trains = np.array(cmip5_trains)\n",
    "cmip5_trains_nan = np.isnan(cmip5_trains)\n",
    "cmip5_trains[cmip5_trains_nan] = 0\n",
    "print('Number of null in cmip6_trains after fillna:', np.sum(np.isnan(cmip5_trains)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 构造数据集\n",
    "构造训练和验证集。注意这里取每条输入数据的序列长度是38，这是因为输入sst序列长度是12，输出sst序列长度是26，在训练中采用teacher forcing策略（这个策略会在之后的模型构造时详细说明），因此这里在构造输入数据时包含了输出sst序列的实际值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:12.165242Z",
     "iopub.status.busy": "2021-11-29T03:05:12.164045Z",
     "iopub.status.idle": "2021-11-29T03:05:12.480257Z",
     "shell.execute_reply": "2021-11-29T03:05:12.479767Z",
     "shell.execute_reply.started": "2021-11-27T13:39:25.418945Z"
    },
    "papermill": {
     "duration": 0.361254,
     "end_time": "2021-11-29T03:05:12.480405",
     "exception": false,
     "start_time": "2021-11-29T03:05:12.119151",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 构造训练集\n",
    "\n",
    "X_train = []\n",
    "y_train = []\n",
    "# 从CMIP5的17种模式中各抽取100条数据\n",
    "for model_i in range(17):\n",
    "    samples = np.random.choice(cmip5_trains.shape[1]-38, size=100)\n",
    "    for ind in samples:\n",
    "        X_train.append(cmip5_trains[model_i, ind: ind+38])\n",
    "        y_train.append(cmip5_labels[model_i][ind: ind+24])\n",
    "# 从CMIP6的15种模式种各抽取100条数据\n",
    "for model_i in range(15):\n",
    "    samples = np.random.choice(cmip6_trains.shape[1]-38, size=100)\n",
    "    for ind in samples:\n",
    "        X_train.append(cmip6_trains[model_i, ind: ind+38])\n",
    "        y_train.append(cmip6_labels[model_i][ind: ind+24])\n",
    "X_train = np.array(X_train)\n",
    "y_train = np.array(y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:12.541232Z",
     "iopub.status.busy": "2021-11-29T03:05:12.540676Z",
     "iopub.status.idle": "2021-11-29T03:05:12.548103Z",
     "shell.execute_reply": "2021-11-29T03:05:12.547520Z",
     "shell.execute_reply.started": "2021-11-27T13:39:26.341849Z"
    },
    "papermill": {
     "duration": 0.040262,
     "end_time": "2021-11-29T03:05:12.548224",
     "exception": false,
     "start_time": "2021-11-29T03:05:12.507962",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 构造测试集\n",
    "\n",
    "X_valid = []\n",
    "y_valid = []\n",
    "samples = np.random.choice(soda_trains.shape[1]-38, size=100)\n",
    "for ind in samples:\n",
    "    X_valid.append(soda_trains[0, ind: ind+38])\n",
    "    y_valid.append(soda_labels[0][ind: ind+24])\n",
    "X_valid = np.array(X_valid)\n",
    "y_valid = np.array(y_valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:12.606407Z",
     "iopub.status.busy": "2021-11-29T03:05:12.605555Z",
     "iopub.status.idle": "2021-11-29T03:05:12.611580Z",
     "shell.execute_reply": "2021-11-29T03:05:12.611152Z",
     "shell.execute_reply.started": "2021-11-27T13:39:27.247585Z"
    },
    "papermill": {
     "duration": 0.036214,
     "end_time": "2021-11-29T03:05:12.611721",
     "exception": false,
     "start_time": "2021-11-29T03:05:12.575507",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((3200, 38, 24, 48, 1), (3200, 24), (100, 38, 24, 48, 1), (100, 24))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 查看数据集维度\n",
    "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:12.737322Z",
     "iopub.status.busy": "2021-11-29T03:05:12.736558Z",
     "iopub.status.idle": "2021-11-29T03:05:13.516712Z",
     "shell.execute_reply": "2021-11-29T03:05:13.517217Z",
     "shell.execute_reply.started": "2021-11-27T13:39:38.421657Z"
    },
    "papermill": {
     "duration": 0.812187,
     "end_time": "2021-11-29T03:05:13.517368",
     "exception": false,
     "start_time": "2021-11-29T03:05:12.705181",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 保存数据集\n",
    "np.save('X_train_sample.npy', X_train)\n",
    "np.save('y_train_sample.npy', y_train)\n",
    "np.save('X_valid_sample.npy', X_valid)\n",
    "np.save('y_valid_sample.npy', y_valid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型构建"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:13.577516Z",
     "iopub.status.busy": "2021-11-29T03:05:13.576992Z",
     "iopub.status.idle": "2021-11-29T03:05:21.917657Z",
     "shell.execute_reply": "2021-11-29T03:05:21.918265Z",
     "shell.execute_reply.started": "2021-11-29T01:03:01.505192Z"
    },
    "papermill": {
     "duration": 8.372964,
     "end_time": "2021-11-29T03:05:21.918443",
     "exception": false,
     "start_time": "2021-11-29T03:05:13.545479",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 读取数据集\n",
    "X_train = np.load('../input/ai-earth-task05-samples/X_train_sample.npy')\n",
    "y_train = np.load('../input/ai-earth-task05-samples/y_train_sample.npy')\n",
    "X_valid = np.load('../input/ai-earth-task05-samples/X_valid_sample.npy')\n",
    "y_valid = np.load('../input/ai-earth-task05-samples/y_valid_sample.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:21.983898Z",
     "iopub.status.busy": "2021-11-29T03:05:21.982953Z",
     "iopub.status.idle": "2021-11-29T03:05:21.986939Z",
     "shell.execute_reply": "2021-11-29T03:05:21.986453Z",
     "shell.execute_reply.started": "2021-11-29T01:03:11.548945Z"
    },
    "papermill": {
     "duration": 0.039398,
     "end_time": "2021-11-29T03:05:21.987066",
     "exception": false,
     "start_time": "2021-11-29T03:05:21.947668",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((3200, 38, 24, 48, 1), (3200, 24), (100, 38, 24, 48, 1), (100, 24))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:22.341929Z",
     "iopub.status.busy": "2021-11-29T03:05:22.340932Z",
     "iopub.status.idle": "2021-11-29T03:05:22.346140Z",
     "shell.execute_reply": "2021-11-29T03:05:22.346878Z",
     "shell.execute_reply.started": "2021-11-29T01:03:11.560457Z"
    },
    "papermill": {
     "duration": 0.143838,
     "end_time": "2021-11-29T03:05:22.347113",
     "exception": false,
     "start_time": "2021-11-29T03:05:22.203275",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 构造数据管道\n",
    "class AIEarthDataset(Dataset):\n",
    "    def __init__(self, data, label):\n",
    "        self.data = torch.tensor(data, dtype=torch.float32)\n",
    "        self.label = torch.tensor(label, dtype=torch.float32)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.label)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        return self.data[idx], self.label[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:22.583350Z",
     "iopub.status.busy": "2021-11-29T03:05:22.582298Z",
     "iopub.status.idle": "2021-11-29T03:05:23.243100Z",
     "shell.execute_reply": "2021-11-29T03:05:23.243851Z",
     "shell.execute_reply.started": "2021-11-29T01:03:23.691846Z"
    },
    "papermill": {
     "duration": 0.825537,
     "end_time": "2021-11-29T03:05:23.244098",
     "exception": false,
     "start_time": "2021-11-29T03:05:22.418561",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "batch_size = 2\n",
    "\n",
    "trainset = AIEarthDataset(X_train, y_train)\n",
    "trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "validset = AIEarthDataset(X_valid, y_valid)\n",
    "validloader = DataLoader(validset, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 构造评估函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:23.655820Z",
     "iopub.status.busy": "2021-11-29T03:05:23.655241Z",
     "iopub.status.idle": "2021-11-29T03:05:23.658416Z",
     "shell.execute_reply": "2021-11-29T03:05:23.658859Z",
     "shell.execute_reply.started": "2021-11-29T01:03:26.481561Z"
    },
    "papermill": {
     "duration": 0.040887,
     "end_time": "2021-11-29T03:05:23.658990",
     "exception": false,
     "start_time": "2021-11-29T03:05:23.618103",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def rmse(y_true, y_preds):\n",
    "    return np.sqrt(mean_squared_error(y_pred = y_preds, y_true = y_true))\n",
    "\n",
    "# 评估函数\n",
    "def score(y_true, y_preds):\n",
    "    # 相关性技巧评分\n",
    "    accskill_score = 0\n",
    "    # RMSE\n",
    "    rmse_scores = 0\n",
    "    a = [1.5] * 4 + [2] * 7 + [3] * 7 + [4] * 6\n",
    "    y_true_mean = np.mean(y_true, axis=0)\n",
    "    y_pred_mean = np.mean(y_preds, axis=0)\n",
    "    for i in range(24):\n",
    "        fenzi = np.sum((y_true[:, i] - y_true_mean[i]) * (y_preds[:, i] - y_pred_mean[i]))\n",
    "        fenmu = np.sqrt(np.sum((y_true[:, i] - y_true_mean[i])**2) * np.sum((y_preds[:, i] - y_pred_mean[i])**2))\n",
    "        cor_i = fenzi / fenmu\n",
    "        accskill_score += a[i] * np.log(i+1) * cor_i\n",
    "        rmse_score = rmse(y_true[:, i], y_preds[:, i])\n",
    "        rmse_scores += rmse_score\n",
    "    return 2/3.0 * accskill_score - rmse_scores"
   ]
  },
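  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out as a formula, the `score` function above computes the following. This is a transcription of the code, with months indexed $i = 1, \\dots, 24$; the symbols $N$ (number of samples) and $y_{n,i}$, $\\hat{y}_{n,i}$ (true and predicted value of sample $n$ at month $i$) are notation introduced here for illustration:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "& cor_i = \\frac{\\sum_{n=1}^{N} (y_{n,i} - \\bar{y}_i)(\\hat{y}_{n,i} - \\bar{\\hat{y}}_i)}{\\sqrt{\\sum_{n=1}^{N} (y_{n,i} - \\bar{y}_i)^2 \\sum_{n=1}^{N} (\\hat{y}_{n,i} - \\bar{\\hat{y}}_i)^2}} \\\\\n",
    "& accskill = \\sum_{i=1}^{24} a_i \\ln(i) \\, cor_i, \\qquad a_i = \\begin{cases} 1.5, & 1 \\le i \\le 4 \\\\ 2, & 5 \\le i \\le 11 \\\\ 3, & 12 \\le i \\le 18 \\\\ 4, & 19 \\le i \\le 24 \\end{cases} \\\\\n",
    "& RMSE_i = \\sqrt{\\frac{1}{N} \\sum_{n=1}^{N} (y_{n,i} - \\hat{y}_{n,i})^2} \\\\\n",
    "& Score = \\frac{2}{3} \\, accskill - \\sum_{i=1}^{24} RMSE_i\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Since $\\ln(1) = 0$, the first month contributes nothing to the correlation term and enters the score only through its RMSE."
   ]
  },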
  {
   "cell_type": "markdown",
   "metadata": {
    "papermill": {
     "duration": 0.028556,
     "end_time": "2021-11-29T03:05:23.310560",
     "exception": false,
     "start_time": "2021-11-29T03:05:23.282004",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "#### 模型构造"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "不同于前两个TOP方案所构建的多输出神经网络，该TOP方案采用的是Seq2Seq结构，以本赛题为例，输入的序列长度是12，输出的序列长度是26，方案中构建了四个隐藏层，那么一个基础的Seq2Seq结构就如下图所示：\n",
    "\n",
    "<img src=\"./fig/Task5-Seq2Seq基础结构.png\" width=\"70%\">\n",
    "\n",
    "要将Seq2Seq结构应用于不同的问题，重点在于使用怎样的Cell（神经元）。在该TOP方案中使用的Cell是清华大学提出的SA-ConvLSTM（Self-Attention ConvLSTM），论文原文可参考https://ojs.aaai.org//index.php/AAAI/article/view/6819\n",
    "\n",
    "SA-ConvLSTM是施行健博士提出的时空序列领域经典模型ConvLSTM的改进模型，为了捕捉空间信息的时序依赖关系，它在ConvLSTM的基础上增加了SAM模块，用来记忆空间的聚合特征。ConvLSTM的论文原文可参考https://arxiv.org/pdf/1506.04214.pdf\n",
    "\n",
    "1. ConvLSTM模型\n",
    "\n",
    "LSTM模型是非常经典的时序模型，三个门的结构使得它在挖掘长期的时间依赖任务中有不俗的表现，并且相较于RNN，LSTM能够有效地避免梯度消失问题。对于单个输入样本，在每个时间步上，LSTM的每个门实际是对输入向量做了一个全连接，那么对应到我们这个赛题上，输入X的形状是（N,T,H,W,C），则单个输入样本在每个时间步上输入LSTM的就是形状为（H,W,C）的空间信息。我们知道，全连接网络对于这种空间信息的提取能力并不强，转换成卷积操作后能够在大大减少参数量的同时通过堆叠多层网络逐步提取出更复杂的特征，到这里就可以很自然地想到，把LSTM中的全连接操作转换为卷积操作，就能够适用于时空序列问题。ConvLSTM模型就是这么做的，实践也表明这样的作法是非常有效的。\n",
    "\n",
    "<img src=\"./fig/Task5-LSTM与ConvLSTM公式比较.png\" width=\"100%\">\n",
    "\n",
    "2. SAM模块\n",
    "\n",
    "然而，ConvLSTM模型存在两个问题：\n",
    "\n",
    "一是卷积层的感受野受限于卷积核的大小，需要通过堆叠多个卷积层来扩大感受野，发掘全局的特征。举例来说，假设第一个卷积层的卷积核大小是3×3，那么这一层的每个节点就只能感知这3×3的空间范围内的输入信息，此时再增加一个3×3的卷积层，那么每个节点所能感知的就是3×3个第一层的节点内的信息，在第一层步长为1的情况下，就是4×4范围内的输入信息，于是相比于第一个卷积层，第二层所能感知的输入信息的空间范围就增大了，而这样做所带来的后果就是参数量增加。对于单纯的CNN模型来说增加一层只是增加了一个卷积核大小的参数量，但是对于ConvLSTM来说就有些不堪重负，参数量的增加增大了过拟合的风险，与此同时模型的收效却并不高。\n",
    "\n",
    "二是卷积操作只针对当前时间步输入的空间信息，而忽视了过去的空间信息，因此难以挖掘空间信息在时间上的依赖关系。\n",
    "\n",
    "因此，为了同时挖掘全局和本地的空间依赖，提升模型在大空间范围和长时间的时空序列预测任务中的预测效果，SA-ConvLSTM模型在ConvLSTM模型的基础上引入了SAM（self-attention memory）模块。\n",
    "\n",
    "<img src=\"./fig/Task5-SAM模块.png\" width=\"50%\">\n",
    "\n",
    "SAM模块引入了一个新的记忆单元M，用来记忆包含时序依赖关系的空间信息。SAM模块以当前时间步通过ConvLSTM所获得的隐藏层状态$H_t$和上一个时间步的记忆$M_{t-1}$作为输入，首先将$H_t$通过自注意力机制得到特征$Z_h$，自注意力机制能够增加$H_t$中与其他部分更相关的部分的权重，同时$H_t$也作为Query与$M_{t-1}$共同通过注意力机制得到特征$Z_m$，用以增强对$M_{t-1}$中与$H_t$有更强依赖关系的部分的权重，将$Z_h$和$Z_m$拼接起来就得到了二者的聚合特征$Z$。此时，聚合特征$Z$中既包含了当前时间步的信息，又包含了全局的时空记忆信息，接下来借鉴LSTM中的门控结构用聚合特征$Z$对隐藏层状态和记忆单元进行更新，就得到了更新后的隐藏层状态$\\hat{H_t}$和当前时间步的记忆$M_t$。SAM模块的公式如下：\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "& i'_t = \\sigma (W_{m;zi} \\ast Z + W_{m;hi} \\ast H_t + b_{m;i}) \\\\\n",
    "& g'_t = tanh (W_{m;zg} \\ast Z + W_{m;hg} \\ast H_t + b_{m;g}) \\\\\n",
    "& M_t = (1 - i'_t) \\circ M_{t-1} + i'_t \\circ g'_t \\\\\n",
    "& o'_t = \\sigma (W_{m;zo} \\ast Z + W_{m;ho} \\ast H_t + b_{m;o}) \\\\\n",
    "& \\hat{H_t} = o'_t \\circ M_t\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "关于注意力机制和自注意力机制可以参考以下链接：\n",
    "\n",
    "   - 深度学习中的注意力机制：https://blog.csdn.net/malefactor/article/details/78767781\n",
    "   - 目前主流的Attention方法：https://www.zhihu.com/question/68482809\n",
    "\n",
    "3. SA-ConvLSTM模型\n",
    "\n",
    "将以上二者结合起来，就得到了SA-ConvLSTM模型：\n",
    "\n",
    "<img src=\"./fig/Task5-SA-ConvLSTM模型.png\" width=\"40%\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:23.372772Z",
     "iopub.status.busy": "2021-11-29T03:05:23.371873Z",
     "iopub.status.idle": "2021-11-29T03:05:23.373700Z",
     "shell.execute_reply": "2021-11-29T03:05:23.374122Z",
     "shell.execute_reply.started": "2021-11-29T01:03:24.585147Z"
    },
    "papermill": {
     "duration": 0.035787,
     "end_time": "2021-11-29T03:05:23.374254",
     "exception": false,
     "start_time": "2021-11-29T03:05:23.338467",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Attention机制\n",
    "def attn(query, key, value):\n",
    "    # query、key、value的形状都是(N, C, H*W)，令S=H*W\n",
    "    # 采用缩放点积模型计算得分，scores(i)=key(i)^T query/根号C\n",
    "    scores = torch.matmul(query.transpose(1, 2), key / math.sqrt(query.size(1)))  # (N, S, S)\n",
    "    # 计算注意力得分\n",
    "    attn = F.softmax(scores, dim=-1)\n",
    "    output = torch.matmul(attn, value.transpose(1, 2))  # (N, S, C)\n",
    "    return output.transpose(1, 2)  # (N, C, S)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:23.440765Z",
     "iopub.status.busy": "2021-11-29T03:05:23.440042Z",
     "iopub.status.idle": "2021-11-29T03:05:23.442191Z",
     "shell.execute_reply": "2021-11-29T03:05:23.442569Z",
     "shell.execute_reply.started": "2021-11-29T01:03:25.147999Z"
    },
    "papermill": {
     "duration": 0.041095,
     "end_time": "2021-11-29T03:05:23.442725",
     "exception": false,
     "start_time": "2021-11-29T03:05:23.401630",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# SAM模块\n",
    "class SAAttnMem(nn.Module):\n",
    "    def __init__(self, input_dim, d_model, kernel_size):\n",
    "        super().__init__()\n",
    "        pad = kernel_size[0] // 2, kernel_size[1] // 2\n",
    "        self.d_model = d_model\n",
    "        self.input_dim = input_dim\n",
    "        # 用1*1卷积实现全连接操作WhHt\n",
    "        self.conv_h = nn.Conv2d(input_dim, d_model*3, kernel_size=1)\n",
    "        # 用1*1卷积实现全连接操作WmMt-1\n",
    "        self.conv_m = nn.Conv2d(input_dim, d_model*2, kernel_size=1)\n",
    "        # 用1*1卷积实现全连接操作Wz[Zh,Zm]\n",
    "        self.conv_z = nn.Conv2d(d_model*2, d_model, kernel_size=1)\n",
    "        # 注意输出维度和输入维度要保持一致，都是input_dim\n",
    "        self.conv_output = nn.Conv2d(input_dim+d_model, input_dim*3, kernel_size=kernel_size, padding=pad)\n",
    "        \n",
    "    def forward(self, h, m):\n",
    "        # self.conv_h(h)得到WhHt，将其在dim=1上划分成大小为self.d_model的块，每一块的形状就是(N, d_model, H, W)，所得到的三块就是Qh、Kh、Vh\n",
    "        hq, hk, hv = torch.split(self.conv_h(h), self.d_model, dim=1)\n",
    "        # 同样的方法得到Km和Vm\n",
    "        mk, mv = torch.split(self.conv_m(m), self.d_model, dim=1)\n",
    "        N, C, H, W = hq.size()\n",
    "        # 通过自注意力机制得到Zh\n",
    "        Zh = attn(hq.view(N, C, -1), hk.view(N, C, -1), hv.view(N, C, -1))  # (N, C, S), C=d_model\n",
    "        # 通过注意力机制得到Zm\n",
    "        Zm = attn(hq.view(N, C, -1), mk.view(N, C, -1), mv.view(N, C, -1))  # (N, C, S), C=d_model\n",
    "        # 将Zh和Zm拼接起来，并进行全连接操作得到聚合特征Z\n",
    "        Z = self.conv_z(torch.cat([Zh.view(N, C, H, W), Zm.view(N, C, H, W)], dim=1))  # (N, C, H, W), C=d_model\n",
    "        # 计算i't、g't、o't\n",
    "        i, g, o = torch.split(self.conv_output(torch.cat([Z, h], dim=1)), self.input_dim, dim=1)  # (N, C, H, W), C=input_dim\n",
    "        i = torch.sigmoid(i)\n",
    "        g = torch.tanh(g)\n",
    "        # 得到更新后的记忆单元Mt\n",
    "        m_next = i * g + (1 - i) * m\n",
    "        # 得到更新后的隐藏状态Ht\n",
    "        h_next = torch.sigmoid(o) * m_next\n",
    "        return h_next, m_next"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:23.509738Z",
     "iopub.status.busy": "2021-11-29T03:05:23.509080Z",
     "iopub.status.idle": "2021-11-29T03:05:23.512667Z",
     "shell.execute_reply": "2021-11-29T03:05:23.512182Z",
     "shell.execute_reply.started": "2021-11-29T01:03:25.667808Z"
    },
    "papermill": {
     "duration": 0.042616,
     "end_time": "2021-11-29T03:05:23.512781",
     "exception": false,
     "start_time": "2021-11-29T03:05:23.470165",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# SA-ConvLSTM Cell\n",
    "class SAConvLSTMCell(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, d_attn, kernel_size):\n",
    "        super().__init__()\n",
    "        self.input_dim = input_dim\n",
    "        self.hidden_dim = hidden_dim\n",
    "        pad = kernel_size[0] // 2, kernel_size[1] // 2\n",
    "        # 卷积操作Wx*Xt+Wh*Ht-1\n",
    "        self.conv = nn.Conv2d(in_channels=input_dim+hidden_dim, out_channels=4*hidden_dim, kernel_size=kernel_size, padding=pad)\n",
    "        self.sa = SAAttnMem(input_dim=hidden_dim, d_model=d_attn, kernel_size=kernel_size)\n",
    "        \n",
    "    def initialize(self, inputs):\n",
    "        device = inputs.device\n",
    "        N, _, H, W = inputs.size()\n",
    "        # 初始化隐藏层状态Ht\n",
    "        self.hidden_state = torch.zeros(N, self.hidden_dim, H, W, device=device)\n",
    "        # 初始化记忆细胞状态ct\n",
    "        self.cell_state = torch.zeros(N, self.hidden_dim, H, W, device=device)\n",
    "        # 初始化记忆单元状态Mt\n",
    "        self.memory_state = torch.zeros(N, self.hidden_dim, H, W, device=device)\n",
    "    \n",
    "    def forward(self, inputs, first_step=False):\n",
    "        # 如果当前是第一个时间步，初始化Ht、ct、Mt\n",
    "        if first_step:\n",
    "            self.initialize(inputs)\n",
    "        \n",
    "        # ConvLSTM部分\n",
    "        # 拼接Xt和Ht\n",
    "        combined = torch.cat([inputs, self.hidden_state], dim=1)  # (N, C, H, W), C=input_dim+hidden_dim\n",
    "        # 进行卷积操作\n",
    "        combined_conv = self.conv(combined)       \n",
    "        # 得到四个门控单元it、ft、ot、gt\n",
    "        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)\n",
    "        i = torch.sigmoid(cc_i)\n",
    "        f = torch.sigmoid(cc_f)\n",
    "        o = torch.sigmoid(cc_o)\n",
    "        g = torch.tanh(cc_g)\n",
    "        # 得到当前时间步的记忆细胞状态ct=ft·ct-1+it·gt\n",
    "        self.cell_state = f * self.cell_state + i * g\n",
    "        # 得到当前时间步的隐藏层状态Ht=ot·tanh(ct)\n",
    "        self.hidden_state = o * torch.tanh(self.cell_state)\n",
    "        \n",
    "        # SAM部分，更新Ht和Mt\n",
    "        self.hidden_state, self.memory_state = self.sa(self.hidden_state, self.memory_state)\n",
    "        \n",
    "        return self.hidden_state"
   ]
  },
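  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reading aid, the update performed by `SAConvLSTMCell.forward` can be written out as equations. This is a sketch transcribed from the code above, not necessarily the notation of the original paper. Here $*$ denotes convolution, $\\odot$ element-wise multiplication and $[\\cdot,\\cdot]$ channel-wise concatenation. $W_i, W_f, W_o, W_g$ are the four output-channel slices of the single convolution `self.conv`, $\\tilde{H}_t$ is a symbol introduced here for the hidden state before the self-attention memory (SAM) step, and $i'$, $g'$, $o'$ stand for the gate tensors computed inside `SAAttnMem`.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "i_t &= \\sigma(W_i * [X_t, H_{t-1}] + b_i) \\\\\n",
    "f_t &= \\sigma(W_f * [X_t, H_{t-1}] + b_f) \\\\\n",
    "o_t &= \\sigma(W_o * [X_t, H_{t-1}] + b_o) \\\\\n",
    "g_t &= \\tanh(W_g * [X_t, H_{t-1}] + b_g) \\\\\n",
    "c_t &= f_t \\odot c_{t-1} + i_t \\odot g_t \\\\\n",
    "\\tilde{H}_t &= o_t \\odot \\tanh(c_t) \\\\\n",
    "M_t &= i' \\odot g' + (1 - i') \\odot M_{t-1} \\\\\n",
    "H_t &= \\sigma(o') \\odot M_t\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The first six lines are a standard ConvLSTM step. The last two are the SAM step, which takes $\\tilde{H}_t$ and $M_{t-1}$ as inputs and returns the hidden state $H_t$ that is passed on to the next layer and the next time step."
   ]
  },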
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在Seq2Seq模型的训练中，有两种训练模式。一是Free running，也就是传统的训练方式，以上一个时间步的输出$\\hat{y_{t-1}}$作为下一个时间步的输入，但是这种做法存在的问题是在训练的初期所得到的$\\hat{y_{t-1}}$与实际标签$y_{t-1}$相差甚远，以此作为输入会导致后续的输出越来越偏离我们期望的预测标签。于是就产生了第二种训练模式——Teacher forcing。\n",
    "\n",
    "Teacher forcing就是直接使用实际标签$y_{t-1}$作为下一个时间步的输入，由老师（ground truth）带领着防止模型越走越偏。但是老师不能总是手把手领着学生走，要逐渐放手让学生自主学习，于是我们使用Scheduled Sampling来控制使用实际标签的概率。我们用ratio来表示Scheduled Sampling的比例，在训练初期，ratio=1，模型完全由老师带领着，随着训练论述的增加，ratio以一定的方式衰减（该方案中使用线性衰减，ratio每次减小一个衰减率decay_rate），每个时间步以ratio的概率从伯努利分布中提取二进制随机数0或1，为1时输入就是实际标签$y_{t-1}$，否则输入为$\\hat{y_{t-1}}$。"
   ]
  },
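  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scheduled sampling rule can be made concrete with a little arithmetic. The following is a sketch based on values that appear in the training loop and log of this notebook (`decay_rate = 8e-5`, ratio decremented once per batch before the forward pass, 1600 batches per epoch, 5 epochs). The symbols $k$, $r_k$, $\\delta$, $m_t$, $X_t$ and $\\hat{X}_t$ are notation introduced here.\n",
    "\n",
    "$$\n",
    "r_k = \\max(1 - k\\,\\delta,\\ 0), \\qquad \\delta = 8 \\times 10^{-5}\n",
    "$$\n",
    "\n",
    "$$\n",
    "m_t \\sim \\mathrm{Bernoulli}(r_k), \\qquad \\mathrm{input}_t = m_t\\, X_t + (1 - m_t)\\, \\hat{X}_t\n",
    "$$\n",
    "\n",
    "Here $k$ counts training batches, $X_t$ is the ground-truth sst frame at time step $t$, and $\\hat{X}_t$ is the frame predicted at step $t-1$. The mask $m_t$ is drawn independently for every sample and every future step.\n",
    "\n",
    "With these numbers the ratio would reach zero only after $1/\\delta = 12500$ batches. The run logged in this notebook has $5 \\times 1600 = 8000$ batches, so it ends at\n",
    "\n",
    "$$\n",
    "r_{8000} = 1 - 8000 \\times 8 \\times 10^{-5} = 0.36,\n",
    "$$\n",
    "\n",
    "which means that in the final batches the ground truth is still fed in about 36% of the time."
   ]
  },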
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:23.587567Z",
     "iopub.status.busy": "2021-11-29T03:05:23.586781Z",
     "iopub.status.idle": "2021-11-29T03:05:23.588776Z",
     "shell.execute_reply": "2021-11-29T03:05:23.589156Z",
     "shell.execute_reply.started": "2021-11-29T01:03:26.065997Z"
    },
    "papermill": {
     "duration": 0.047514,
     "end_time": "2021-11-29T03:05:23.589277",
     "exception": false,
     "start_time": "2021-11-29T03:05:23.541763",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 构建SA-ConvLSTM模型\n",
    "class SAConvLSTM(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, d_attn, kernel_size):\n",
    "        super().__init__()\n",
    "        self.input_dim = input_dim\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.num_layers = len(hidden_dim)\n",
    "        \n",
    "        layers = []\n",
    "        for i in range(self.num_layers):\n",
    "            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1]\n",
    "            layers.append(SAConvLSTMCell(input_dim=cur_input_dim, hidden_dim=self.hidden_dim[i], d_attn = d_attn, kernel_size=kernel_size))            \n",
    "        self.layers = nn.ModuleList(layers)\n",
    "        \n",
    "        self.conv_output = nn.Conv2d(self.hidden_dim[-1], 1, kernel_size=1)\n",
    "        \n",
    "    def forward(self, input_x, device=torch.device('cuda:0'), input_frames=12, future_frames=26, output_frames=37, teacher_forcing=False, scheduled_sampling_ratio=0, train=True):\n",
    "        # 将输入样本X的形状(N, T, H, W, C)转换为(N, T, C, H, W)\n",
    "        input_x = input_x.permute(0, 1, 4, 2, 3).contiguous()\n",
    "        \n",
    "        # 仅在训练时使用teacher forcing\n",
    "        if train:\n",
    "            if teacher_forcing and scheduled_sampling_ratio > 1e-6:\n",
    "                teacher_forcing_mask = torch.bernoulli(scheduled_sampling_ratio * torch.ones(input_x.size(0), future_frames-1, 1, 1, 1))\n",
    "            else:\n",
    "                teacher_forcing = False\n",
    "        else:\n",
    "            teacher_forcing = False\n",
    "            \n",
    "        total_steps = input_frames + future_frames - 1\n",
    "        outputs = [None] * total_steps\n",
    "        \n",
    "        # 对于每一个时间步\n",
    "        for t in range(total_steps):\n",
    "            # 在前12个月，使用每个月的输入样本Xt\n",
    "            if t < input_frames:\n",
    "                input_ = input_x[:, t].to(device)\n",
    "            # 若不使用teacher forcing，则以上一个时间步的预测标签作为当前时间步的输入\n",
    "            elif not teacher_forcing:\n",
    "                input_ = outputs[t-1]\n",
    "            # 若使用teacher forcing，则以ratio的概率使用上一个时间步的实际标签作为当前时间步的输入\n",
    "            else:\n",
    "                mask = teacher_forcing_mask[:, t-input_frames].float().to(device)\n",
    "                input_ = input_x[:, t].to(device) * mask + outputs[t-1] * (1-mask)\n",
    "            first_step = (t==0)\n",
    "            input_ = input_.float()\n",
    "            \n",
    "            # 将当前时间步的输入通过隐藏层\n",
    "            for layer_idx in range(self.num_layers):\n",
    "                input_ = self.layers[layer_idx](input_, first_step=first_step)\n",
    "            \n",
    "            # 记录每个时间步的输出\n",
    "            if train or (t >= (input_frames - 1)):\n",
    "                outputs[t] = self.conv_output(input_)\n",
    "                \n",
    "        outputs = [x for x in outputs if x is not None]\n",
    "        \n",
    "        # 确认输出序列的长度\n",
    "        if train:\n",
    "            assert len(outputs) == output_frames\n",
    "        else:\n",
    "            assert len(outputs) == future_frames\n",
    "        \n",
    "        # 得到sst的预测序列\n",
    "        outputs = torch.stack(outputs, dim=1)[:, :, 0]  # (N, 37, H, W)\n",
    "        # 对sst的预测序列在nino3.4区域取三个月的平均值就得到nino3.4指数的预测序列\n",
    "        nino_pred = outputs[:, -future_frames:, 10:13, 19:30].mean(dim=[2, 3])  # (N, 26)\n",
    "        nino_pred = nino_pred.unfold(dimension=1, size=3, step=1).mean(dim=2)  # (N, 24)\n",
    "        \n",
    "        return nino_pred"
   ]
  },
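  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The frame bookkeeping in `forward` is easy to lose track of, so here is the arithmetic spelled out, assuming the default arguments `input_frames=12`, `future_frames=26` and `output_frames=37`. The symbols $T$, $\\bar{s}_k$ and $\\widehat{\\mathrm{nino}}_k$ are notation introduced here.\n",
    "\n",
    "$$\n",
    "T = \\mathrm{input\\_frames} + \\mathrm{future\\_frames} - 1 = 12 + 26 - 1 = 37\n",
    "$$\n",
    "\n",
    "- In training, an output is recorded at every step, giving $37$ frames, which matches `output_frames`.\n",
    "- In evaluation, outputs are recorded only for $t \\ge 11$, giving $37 - 11 = 26$ frames, which matches `future_frames`.\n",
    "\n",
    "For the index itself, let $\\bar{s}_k$ be the mean of the $k$-th of the last 26 predicted sst frames over the $3 \\times 11 = 33$ grid cells selected by `10:13, 19:30`. The 3-month running mean computed by `unfold` is then\n",
    "\n",
    "$$\n",
    "\\widehat{\\mathrm{nino}}_k = \\frac{1}{3}\\left(\\bar{s}_k + \\bar{s}_{k+1} + \\bar{s}_{k+2}\\right), \\qquad k = 1, \\dots, 24,\n",
    "$$\n",
    "\n",
    "and a window of size 3 with step 1 over 26 values yields $26 - 3 + 1 = 24$ outputs, one per target month."
   ]
  },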
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:23.726291Z",
     "iopub.status.busy": "2021-11-29T03:05:23.725688Z",
     "iopub.status.idle": "2021-11-29T03:05:23.753509Z",
     "shell.execute_reply": "2021-11-29T03:05:23.753976Z",
     "shell.execute_reply.started": "2021-11-29T01:03:29.448921Z"
    },
    "papermill": {
     "duration": 0.066105,
     "end_time": "2021-11-29T03:05:23.754109",
     "exception": false,
     "start_time": "2021-11-29T03:05:23.688004",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SAConvLSTM(\n",
      "  (layers): ModuleList(\n",
      "    (0): SAConvLSTMCell(\n",
      "      (conv): Conv2d(65, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (sa): SAAttnMem(\n",
      "        (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n",
      "        (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "        (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n",
      "        (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      )\n",
      "    )\n",
      "    (1): SAConvLSTMCell(\n",
      "      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (sa): SAAttnMem(\n",
      "        (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n",
      "        (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "        (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n",
      "        (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      )\n",
      "    )\n",
      "    (2): SAConvLSTMCell(\n",
      "      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (sa): SAAttnMem(\n",
      "        (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n",
      "        (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "        (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n",
      "        (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      )\n",
      "    )\n",
      "    (3): SAConvLSTMCell(\n",
      "      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (sa): SAAttnMem(\n",
      "        (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n",
      "        (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "        (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n",
      "        (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (conv_output): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# 输入特征数\n",
    "input_dim = 1\n",
    "# 隐藏层节点数\n",
    "hidden_dim = (64, 64, 64, 64)\n",
    "# 注意力机制节点数\n",
    "d_attn = 32\n",
    "# 卷积核大小\n",
    "kernel_size = (3, 3)\n",
    "\n",
    "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:23.816671Z",
     "iopub.status.busy": "2021-11-29T03:05:23.815927Z",
     "iopub.status.idle": "2021-11-29T03:05:23.818479Z",
     "shell.execute_reply": "2021-11-29T03:05:23.818058Z",
     "shell.execute_reply.started": "2021-11-29T01:03:31.476806Z"
    },
    "papermill": {
     "duration": 0.035723,
     "end_time": "2021-11-29T03:05:23.818579",
     "exception": false,
     "start_time": "2021-11-29T03:05:23.782856",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 采用RMSE作为损失函数\n",
    "def RMSELoss(y_pred,y_true):\n",
    "    loss = torch.sqrt(torch.mean((y_pred-y_true)**2, dim=0)).sum()\n",
    "    return loss"
   ]
  },
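  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written as a formula, `RMSELoss` computes the following when applied to the model's 24-month output. $N$ is the batch size, and $\\hat{y}_{n,j}$ and $y_{n,j}$ are the predicted and true Nino3.4 index of sample $n$ at lead month $j$ (notation introduced here):\n",
    "\n",
    "$$\n",
    "L = \\sum_{j=1}^{24} \\sqrt{\\frac{1}{N} \\sum_{n=1}^{N} \\left(\\hat{y}_{n,j} - y_{n,j}\\right)^2}\n",
    "$$\n",
    "\n",
    "The square root is taken per lead month before summing, so the loss is a sum of 24 per-month RMSE values and not a single RMSE over all entries. Its magnitude therefore scales with the number of lead months."
   ]
  },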
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T03:05:23.893469Z",
     "iopub.status.busy": "2021-11-29T03:05:23.892684Z",
     "iopub.status.idle": "2021-11-29T04:55:28.956056Z",
     "shell.execute_reply": "2021-11-29T04:55:28.956434Z"
    },
    "papermill": {
     "duration": 6605.109145,
     "end_time": "2021-11-29T04:55:28.956614",
     "exception": false,
     "start_time": "2021-11-29T03:05:23.847469",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1/5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1600/1600 [21:43<00:00,  1.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Loss: 3.289\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:11,  4.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Loss: 44.009\n",
      "Score: -43.458\n",
      "Epoch: 2/5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1600/1600 [21:43<00:00,  1.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Loss: 3.084\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:11,  4.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Loss: 25.011\n",
      "Score: -19.966\n",
      "Epoch: 3/5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1600/1600 [21:46<00:00,  1.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Loss: 13.461\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:12,  4.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Loss: 15.438\n",
      "Score: -14.139\n",
      "Epoch: 4/5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1600/1600 [21:54<00:00,  1.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Loss: 17.627\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:12,  3.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Loss: 15.389\n",
      "Score: -22.500\n",
      "Epoch: 5/5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1600/1600 [21:55<00:00,  1.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Loss: 17.592\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:11,  4.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Loss: 15.252\n",
      "Score: -14.459\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model_weights = './task05_model_weights.pth'\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size).to(device)\n",
    "criterion = RMSELoss\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "lr_scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.3, patience=0, verbose=True, min_lr=0.0001)\n",
    "epochs = 5\n",
    "ratio, decay_rate = 1, 8e-5\n",
    "train_losses, valid_losses = [], []\n",
    "scores = []\n",
    "best_score = float('-inf')\n",
    "preds = np.zeros((len(y_valid),24))\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    print('Epoch: {}/{}'.format(epoch+1, epochs))\n",
    "    \n",
    "    # 模型训练\n",
    "    model.train()\n",
    "    losses = 0\n",
    "    for data, labels in tqdm(trainloader):\n",
    "        data = data.to(device)\n",
    "        labels = labels.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        # ratio线性衰减\n",
    "        ratio = max(ratio-decay_rate, 0)\n",
    "        pred = model(data, teacher_forcing=True, scheduled_sampling_ratio=ratio, train=True)\n",
    "        loss = criterion(pred, labels)\n",
    "        losses += loss.cpu().detach().numpy()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    train_loss = losses / len(trainloader)\n",
    "    train_losses.append(train_loss)\n",
    "    print('Training Loss: {:.3f}'.format(train_loss))\n",
    "    \n",
    "    # 模型验证\n",
    "    model.eval()\n",
    "    losses = 0\n",
    "    with torch.no_grad():\n",
    "        for i, data in tqdm(enumerate(validloader)):\n",
    "            data, labels = data\n",
    "            data = data.to(device)\n",
    "            labels = labels.to(device)\n",
    "            pred = model(data, train=False)\n",
    "            loss = criterion(pred, labels)\n",
    "            losses += loss.cpu().detach().numpy()\n",
    "            preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()\n",
    "    valid_loss = losses / len(validloader)\n",
    "    valid_losses.append(valid_loss)\n",
    "    print('Validation Loss: {:.3f}'.format(valid_loss))\n",
    "    s = score(y_valid, preds)\n",
    "    scores.append(s)\n",
    "    print('Score: {:.3f}'.format(s))\n",
    "    \n",
    "    # 保存最佳模型权重\n",
    "    if s > best_score:\n",
    "        best_score = s\n",
    "        checkpoint = {'best_score': s,\n",
    "                      'state_dict': model.state_dict()}\n",
    "        torch.save(checkpoint, model_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T04:55:33.957872Z",
     "iopub.status.busy": "2021-11-29T04:55:33.957066Z",
     "iopub.status.idle": "2021-11-29T04:55:33.960119Z",
     "shell.execute_reply": "2021-11-29T04:55:33.959684Z",
     "shell.execute_reply.started": "2021-11-28T14:00:36.33194Z"
    },
    "papermill": {
     "duration": 2.38263,
     "end_time": "2021-11-29T04:55:33.960247",
     "exception": false,
     "start_time": "2021-11-29T04:55:31.577617",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 绘制训练/验证曲线\n",
    "def training_vis(train_losses, valid_losses):\n",
    "    # 绘制损失函数曲线\n",
    "    fig = plt.figure(figsize=(8,4))\n",
    "    # subplot loss\n",
    "    ax1 = fig.add_subplot(121)\n",
    "    ax1.plot(train_losses, label='train_loss')\n",
    "    ax1.plot(valid_losses,label='val_loss')\n",
    "    ax1.set_xlabel('Epochs')\n",
    "    ax1.set_ylabel('Loss')\n",
    "    ax1.set_title('Loss on Training and Validation Data')\n",
    "    ax1.legend()\n",
    "    plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T04:55:38.227343Z",
     "iopub.status.busy": "2021-11-29T04:55:38.226636Z",
     "iopub.status.idle": "2021-11-29T04:55:38.470256Z",
     "shell.execute_reply": "2021-11-29T04:55:38.469252Z",
     "shell.execute_reply.started": "2021-11-28T14:00:43.42651Z"
    },
    "papermill": {
     "duration": 2.378943,
     "end_time": "2021-11-29T04:55:38.470387",
     "exception": false,
     "start_time": "2021-11-29T04:55:36.091444",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS8AAAEYCAYAAAANoXDNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAsYUlEQVR4nO3dd3wUdf7H8dcnhSRApIYSeu8QMKKIngoWRATPE1AsgIVT0cMu3umpWE899TwVT08BFQsWDsQu4A89UQiYUKRzlBBCDwQlkPL5/TETWCCBhOxkspvP8/HYR2ZnZmfeu5l8MvPd78yIqmKMMaEmwu8AxhhzIqx4GWNCkhUvY0xIsuJljAlJVryMMSHJipcxJiRZ8aqkRORzERke7Hn9JCLrRORcD5b7rYhc7w5fKSJflWTeE1hPUxHZKyKRJ5q1MqlUxcurjbu8uBt24aNARPYFPL+yNMtS1QtVdVKw562IRGSsiMwpYnxdETkgIp1LuixVnayq5wcp12Hbo6puUNXqqpofjOUfsS4VkV/dbWWHiMwUkaGleP3ZIpIe7FxlUamKV6hzN+zqqlod2ABcHDBucuF8IhLlX8oK6W3gdBFpccT4y4HFqrrEh0x+6OZuO+2AicCLIvKgv5FOnBUvQERiROR5EclwH8+LSIw7ra6IzBCRLBHZKSLfiUiEO+1eEdkkItkiskJE+haz/Boi8qaIbBOR9SJyf8AyRojI9yLyjIjsEpH/iciFpcx/toiku3kygQkiUsvNvc1d7gwRaRzwmsBDoWNmKOW8LURkjvuZfCMiL4nI28XkLknGR0Tkv+7yvhKRugHTr3Y/zx0i8pfiPh9VTQdmAVcfMeka4M3j5Tgi8wgR+T7g+XkislxEdovIi4AETGslIrPcfNtFZLKI1HSnvQU0BT5x94buEZHm7h5SlDtPoohMd7e71SJyQ8CyHxKRKe52lS0iS0UkubjP4IjPY7uqvgXcBNwnInXcZY4UkWXu8taKyB/d8dWAz4FEObSnnygiPUVkrvu3sVlEXhSRKiXJEAxWvBx/AU4DkoBuQE/gfnfanUA6kADUB/4MqIi0A24BTlHVeOACYF0xy/8nUANoCZyF80czMmD6qcAKoC7wFPC6iMiRCzmOBkBtoBkwCud3O8F93hTYB7x4jNeXJsOx5n0HmAfUAR7i6IIRqCQZh+F8VvWAKsBdACLSERjvLj/RXV+RBcc1KTCL+/tLcvOW9rMqXEZd4GOcbaUusAboHTgL8ISbrwPQBOczQVWv5vC956eKWMV7ONteInAZ8LiI9AmYPtCdpyYwvSSZjzANiMLZ3gG2AgOAk3A+8+dEpIeq/gpcCGQE7OlnAPnA7e577wX0BW4uZYYTp6qV5oFTXM4tYvwaoH/A8wuAde7wOJxfcusjXtMa55d9LhB9jHVGAgeAjgHj/gh86w6PAFYHTKsKKNCgpO8FONtdR+wx5k8CdgU8/xa4viQZSjovzh9+HlA1YPrbwNsl/P0UlfH+gOc3A1+4w38F3guYVs39DI76/Qbk3AOc7j5/DJh2gp/V9+7wNcCPAfMJTrG5vpjlXgL8XNz2CDR3P8sonEKXD8QHTH8CmOgOPwR8EzCtI7DvGJ+tcsQ27I7PBK4s5jX/AcYEbGPpx/n93QZMLcnvOhgP2/NyJALrA56vd8cBPA2sBr5yd6XHAqjqapxf1kPAVhF5T0QSOVpdILqI5TcKeJ5ZOKCqv7mD1Uv5Hrapak7hExGpKiL/cg+r9gBzgJpS/DdZpclQ3LyJwM6AcQAbiwtcwoyZAcO/BWRKDFy2OnsHO4pbl5vpA+Aady/xSuDNUuQoypEZNPC5iNR3t4tN7nLfxtkeSqLws8wOGFfsdoPz2cRKKdo7RSQa54hip/v8QhH50T1MzQL6HyuviLR1D7Ez3ff3+LHmDzYrXo4MnEOGQk3dcahqtqreqaotcXbT7xC3bUtV31HVM9zXKvC3Ipa9
HcgtYvmbgvwejrw8yJ04DbOnqupJwO/c8aU9HC2NzUBtEakaMK7JMeYvS8bNgct211nnOK+ZBAwBzgPigU/KmOPIDMLh7/dxnN9LF3e5Vx2xzGNd0iUD57OMDxgX7O1mEM6e8jxx2ng/Ap4B6qtqTeCzgLxFZR0PLAfauO/vz3i7fR2mMhavaBGJDXhEAe8C94tIgtuO8Vec/5KIyAARae1umLtxduULRKSdiPRxf+k5OO0kBUeuTJ2vvacAj4lIvIg0A+4oXL6H4t1MWSJSG/D8WyVVXQ+kAA+JSBUR6QVc7FHGD4EBInKG20g8juNvz98BWcCrOIecB8qY41Ogk4hc6m5Hf8I5fC4UD+wFdotII+DuI16/Bacd9CiquhH4AXjC3U67AtcRhO1GRGqL07XmJeBvqroDpz0xBtgG5InzJUxgl5AtQB0RqXHE+9sD7BWR9jhfAJSbyli8PsPZUAsfDwGP4vzRLQIWAwvdcQBtgG9wNsK5wMuqOhvnF/0kzp5VJk6D8n3FrPNW4FdgLfA9TiPxG8F9W0d5Hohz8/0IfOHx+gpdidN4uwPnM3wf2F/MvM9zghlVdSkwGuez3AzswmlvOtZrFOdQsZn7s0w5VHU7MBhnO9iBs638N2CWh4EeOP/0PsVp3A/0BM4/zSwRuauIVVyB0w6WAUwFHlTVb0qSrRhpIrIXpxnkeuB2Vf2r+16ycYrvFJzPchjOlwCF73U5zj/5tW7eRJwvT4YB2cBrOL/rciNuQ5sxnhCR94Hlqhqy/YlMxVQZ97yMh0TkFHH6N0WISD+cdpX/+BzLhCHriW2CrQHO4VEdnMO4m1T1Z38jmXBkh43GmJBkh43GmJAUEoeNdevW1ebNm/sdwxhTzhYsWLBdVROKmhYSxat58+akpKT4HcMYU85EZH1x0+yw0RgTkqx4GWNCkhUvY0xICok2L2MqotzcXNLT08nJyTn+zOaYYmNjady4MdHR0SV+jRUvY05Qeno68fHxNG/enNJfO9IUUlV27NhBeno6LVoceaXu4tlhozEnKCcnhzp16ljhKiMRoU6dOqXeg7XiZUwZWOEKjhP5HMOreOXsgR/Hg53yZEzYC6/itXwGfDEWUicff15jTEgLr+LV9XJo2gu+uh9+3e53GmM8lZWVxcsvv1zq1/Xv35+srKxSv27EiBF8+OGHpX6dV8KreEVEwIDnYf9ep4AZE8aKK155eXnHfN1nn31GzZo1PUpVfsKvq0S99tB7DHz3DHS7Alqe5XciUwk8/MlSfsnYE9Rldkw8iQcv7lTs9LFjx7JmzRqSkpKIjo4mNjaWWrVqsXz5clauXMkll1zCxo0bycnJYcyYMYwaNQo4dK7w3r17ufDCCznjjDP44YcfaNSoEdOmTSMuLu642WbOnMldd91FXl4ep5xyCuPHjycmJoaxY8cyffp0oqKiOP/883nmmWf44IMPePjhh4mMjKRGjRrMmTMnKJ9PeO15FfrdXVC7Jcy4HXKtA6EJT08++SStWrUiNTWVp59+moULF/KPf/yDlStXAvDGG2+wYMECUlJSeOGFF9ix4+g7w61atYrRo0ezdOlSatasyUcffXTc9ebk5DBixAjef/99Fi9eTF5eHuPHj2fHjh1MnTqVpUuXsmjRIu6/3zn6GTduHF9++SVpaWlMnz79OEsvufDb8wKIjoOLnoW3LoHv/g59ir0TvDFBcaw9pPLSs2fPwzp5vvDCC0ydOhWAjRs3smrVKurUOfzucC1atCApKQmAk08+mXXr1h13PStWrKBFixa0bdsWgOHDh/PSSy9xyy23EBsby3XXXceAAQMYMGAAAL1792bEiBEMGTKESy+9NAjv1BGee14Arc6BrkPh++dg2wq/0xjjuWrVqh0c/vbbb/nmm2+YO3cuaWlpdO/evchOoDExMQeHIyMjj9tedixRUVHMmzePyy67jBkzZtCvXz8AXnnlFR599FE2btzIySefXOQe4IkI3+IFcP5j
UKUafHIbFBx1S0VjQlp8fDzZ2dlFTtu9eze1atWiatWqLF++nB9//DFo623Xrh3r1q1j9erVALz11lucddZZ7N27l927d9O/f3+ee+450tLSAFizZg2nnnoq48aNIyEhgY0bi72JeqmE52FjoeoJcP4jMP1Wp+9Xj6v9TmRM0NSpU4fevXvTuXNn4uLiqF+//sFp/fr145VXXqFDhw60a9eO0047LWjrjY2NZcKECQwePPhgg/2NN97Izp07GTRoEDk5Oagqzz77LAB33303q1atQlXp27cv3bp1C0qOkLgBR3Jysp7wlVRVYeJFsGUp3JLiFDRjgmDZsmV06NDB7xhho6jPU0QWqGpyUfOH92EjgAgMeA4O/ApfWcO9MeEi/IsXQEI7OON2WPQ+rJntdxpjKrTRo0eTlJR02GPChAl+xzpKeLd5BTrzTljyIXx6B9z0g9OdwhhzlJdeesnvCCVSOfa8AKJjncPHnWthzjN+pzHGlFHlKV4ALc92Thn67z9g6zK/0xhjyqByFS+A8x+FmOrOqUPW98uYkFX5ile1uk4B2zAXfn7L7zTGmBPkefESkUgR+VlEZrjPW4jITyKyWkTeF5EqXmc4StKV0OwM+PoB2Lu13FdvjB+qV69e7LR169bRuXPnckxTduWx5zUGCGxg+hvwnKq2BnYB15VDhsMV9v3K3Qdf/rncV2+MKTtPu0qISGPgIuAx4A5xrrLfBxjmzjIJeAgY72WOIiW0hTPugP970mnEb9233COYMPL5WMhcHNxlNugCFz5Z7OSxY8fSpEkTRo8eDcBDDz1EVFQUs2fPZteuXeTm5vLoo48yaNCgUq02JyeHm266iZSUFKKionj22Wc555xzWLp0KSNHjuTAgQMUFBTw0UcfkZiYyJAhQ0hPTyc/P58HHniAoUOHlultl5TXe17PA/cAhS3jdYAsVS08dT0daFTUC0VklIikiEjKtm3bvEl3xu1Qp7XT9yt3nzfrMMYjQ4cOZcqUKQefT5kyheHDhzN16lQWLlzI7NmzufPOOyntKYAvvfQSIsLixYt59913GT58ODk5ObzyyiuMGTOG1NRUUlJSaNy4MV988QWJiYmkpaWxZMmSg1eSKA+e7XmJyABgq6ouEJGzS/t6VX0VeBWccxuDm85V2Pdr0sUw52no+1dPVmMqgWPsIXmle/fubN26lYyMDLZt20atWrVo0KABt99+O3PmzCEiIoJNmzaxZcsWGjRoUOLlfv/999x6660AtG/fnmbNmrFy5Up69erFY489Rnp6Opdeeilt2rShS5cu3Hnnndx7770MGDCAM88806u3exQv97x6AwNFZB3wHs7h4j+AmiJSWDQbA5s8zHB8LX4H3YY5fb+2/OJrFGNKa/DgwXz44Ye8//77DB06lMmTJ7Nt2zYWLFhAamoq9evXL/XNXIszbNgwpk+fTlxcHP3792fWrFm0bduWhQsX0qVLF+6//37GjRsXlHWVhGfFS1XvU9XGqtocuByYpapXArOBy9zZhgPTvMpQYuc/CjEnwYzbrO+XCSlDhw7lvffe48MPP2Tw4MHs3r2bevXqER0dzezZs1m/fn2pl3nmmWcyebJz+8CVK1eyYcMG2rVrx9q1a2nZsiV/+tOfGDRoEIsWLSIjI4OqVaty1VVXcffdd7Nw4cJgv8Vi+dHP616cxvvVOG1gr/uQ4XDV6sAFj8HGn2DhJL/TGFNinTp1Ijs7m0aNGtGwYUOuvPJKUlJS6NKlC2+++Sbt27cv9TJvvvlmCgoK6NKlC0OHDmXixInExMQwZcoUOnfuTFJSEkuWLOGaa65h8eLF9OzZk6SkJB5++OGD160vD+F/Pa+SUnXavjIXwej5EF//+K8xlZpdzyu47HpeJ+qwvl/3+Z3GGHMcleeSOCVRtw2ceRd8+7jTiN/mXL8TGRNUixcv5uqrD78cekxMDD/99JNPiU6cFa8jnXEbLP7A6ft1849QparfiUwFpqo4fa9D
Q5cuXUhNTfU7xlFOpPnKDhuPFBUDFz8PWethzlN+pzEVWGxsLDt27DihPzxziKqyY8cOYmNjS/U62/MqSvMzIOkq+OGf0GUw1Pf/hqKm4mncuDHp6el4dgZIJRIbG0vjxo1L9RorXsU5/xFY+blzz8drv4QI20k1h4uOjj7sDtWmfNlfZHGq1oYLHof0ebCg4t18wJjKzorXsXQd6pw+9M3DkJ3pdxpjTAArXsciAhc9B3k58IX1/TKmIrHidTx1W8Pv7oKlH8Oqr/1OY4xxWfEqid5joG47mHGHc+dtY4zvrHiVRFSMc+rQ7g3wf3/zO40xBiteJde8N3S/Gn54ETKX+J3GmErPildpnDcO4mrBJ2OgIN/vNMZUala8SqOw79emFEh5w+80xlRqVrxKq+sQaHk2zBwHezb7ncaYSsuKV2mJwEXPQt5++GKs32mMqbSseJ2IOq3grLvhl//Ayi/9TmNMpWTF60SdPgYS2sOnd1nfL2N8YMXrREVVgQHPO32/vn3C7zTGVDpWvMqiWS/oMRzmvgybF/mdxphKxYpXWZ37kNOFYsZt1vfLmHJkxausqtaGC56ATQtgvv+3oDSmsrDiFQxdLoNWfdy+Xxl+pzGmUrDiFQwicNHfoSAXPr/X7zTGVApWvIKldks46x5YNh1WfO53GmPCnhWvYOp1KyR0gM/uhv17/U5jTFiz4hVMUVWcez7u3mh9v4zxmBWvYGt6Gpw8En58GTan+Z3GmLBlxcsL5z4IVevadb+M8ZAVLy/E1YJ+T0DGzzDvNb/TGBOWrHh5pfMfoFVfmPUI7N7kdxpjwo4VL6+IwIBnncPGz+/xO40xYceKl5dqNYez74XlM2D5p36nMSasWPHyWq9boF5Ht+9Xtt9pjAkbVry8FhkNF/8D9myC2Y/7ncaYsGHFqzw06QnJ18JPrzjfQBpjysyKV3np+yBUS3D6fuXn+Z3GmJBnxau8xNWEfk86ve7nW98vY8rKs+IlIrEiMk9E0kRkqYg87I5vISI/ichqEXlfRKp4laHC6fR7aH0ezHoUdqf7ncaYkOblntd+oI+qdgOSgH4ichrwN+A5VW0N7AKu8zBDxSICFz3j9P36zPp+GVMWnhUvdRReFybafSjQB/jQHT8JuMSrDBVSreZwzn2w4lNYNsPvNMaELE/bvEQkUkRSga3A18AaIEtVC1us04FGXmaokE67Gep3tr5fxpSBp8VLVfNVNQloDPQE2pf0tSIySkRSRCRl27ZtXkX0R2S0c8/H7M1O+5cxptTK5dtGVc0CZgO9gJoiEuVOagwUedayqr6qqsmqmpyQkFAeMctXk1PglOvgp385dx4yxpSKl982JohITXc4DjgPWIZTxC5zZxsOTPMqQ4XX969QvT58cpv1/TKmlLzc82oIzBaRRcB84GtVnQHcC9whIquBOkDlvdlhbA248G+QuQjm/cvvNMaElKjjz3JiVHUR0L2I8Wtx2r8MQMdB0OYCmPUYdBgINZv4nciYkGA97P0mAv2fBtT59lHV70TGhAQrXhVBrWZw9n2w8nNY9onfaYwJCVa8KorTbob6XZyrrubs8TuNMRWeFa+KIjLKue5Xdqb1/TKmBKx4VSSNT4aeN8C8VyHd+n4ZcyxWvCqaPvdDfAO77pcxx2HFq6Ip7Pu1ZTH8NN7vNMZUWFa8KqIOA6Hthc4177M2+J3GmArJildFdLDvl8Cnd1nfL2OKYMWroqrZBM75M6z6En6pvKd/GlMcK14V2ak3QoOu8Pm9kLPb7zTGVChWvCqywr5fv26FmY/4ncaYCsWKV0XXqAf0HAXz/w3pKX6nMabCsOIVCs75C8Q3dPt+5fqdxpgKwYpXKIg9Cfo/BVuWwI8v+53GmArBileoaD8A2vWH2U/ArvV+pzHGd1a8QkVh3y+JgM+s75cxVrxCSY3GzrmPq76CpVP9TmOMr0pUvESkmohEuMNtRWSgiER7G80UqecoaNjN6fu1fZXfaYzx
TUn3vOYAsSLSCPgKuBqY6FUocwyRUTDoZdACeK0PrPzS70TG+KKkxUtU9TfgUuBlVR0MdPIuljmmBp1h1LdQqzm8MxTmPGNtYKbSKXHxEpFewJXAp+64SG8imRKp2QSu/RK6DIZZj8AHw2H/Xr9TGVNuSlq8bgPuA6aq6lIRaYlz81jjpypV4dJX4fxHnRt3vH4e7FzrdypjyoVoKQ833Ib76qpabneJSE5O1pQUOzXmmNbMgg9GOsODJ0CrPv7mMSYIRGSBqiYXNa2k3za+IyIniUg1YAnwi4jcHcyQpoxa9XHawU5qBG//Af77grWDmbBW0sPGju6e1iXA50ALnG8cTUVSuwVc9xV0uBi+fgA+vgEO/OZ3KmM8UdLiFe3267oEmK6quYD9W6+IYqrD4EnQ96+w+EN44wK7lLQJSyUtXv8C1gHVgDki0gywO6NWVCJw5p0wbIpzHuSrZ8P/vvM7lTFBVaLipaovqGojVe2vjvXAOR5nM2XV9ny4YRZUrQtvDoKf/mXtYCZslLTBvoaIPCsiKe7j7zh7Yaaiq9sarv8G2l4An98D00ZDbo7fqYwps5IeNr4BZAND3MceYIJXoUyQxZ4EQyfDWWMhdTJMuBB2b/I7lTFlUtLi1UpVH1TVte7jYaCll8FMkEVEwDn3OUVs+0qnHWzDj36nMuaElbR47RORMwqfiEhvYJ83kYynOgyA62c630pOHAApb/idyJgTElXC+W4E3hSRGu7zXcBwbyIZz9VrDzfMho+uhxm3w+Y0uPBpiKridzJjSqyk3zamqWo3oCvQVVW7A3b+SSiLqwnD3ocz7oAFE2HSAMjO9DuVMSVWqiupquqegHMa7/AgjylPEZFw7oNw2QTIXOy0g6Uv8DuVMSVSlstAS9BSGH91vhSu+xoiq8CEfvDz234nMua4ylK8rLdjOCm8wGHTXk5fsM/usXtEmgrtmA32IpJN0UVKgDhPEhn/VK0NV30M3zwIc1+ELUthyCSoVtfvZMYc5Zh7Xqoar6onFfGIV9WSflNpQklkFFzwGPz+VdiU4rSDZaT6ncqYo3h26zMRaSIis0XkFxFZKiJj3PG1ReRrEVnl/qzlVQZTBt2GwrVfOOdCvnEBLJridyJjDuPlfRvzgDtVtSNwGjBaRDoCY4GZqtoGmOk+NxVRYnenHazRyc61wb78C+Tn+Z3KGMDD4qWqm1V1oTucDSwDGgGDgEnubJNwrhFmKqrqCXDNNOd+kXNfhMl/gN92+p3KmPK5Y7aINAe6Az8B9VV1szspE6hfzGtGFV7FYtu2beUR0xQnMhr6Pw0DX4T1PzjtYJlL/E5lKjnPi5eIVAc+Am478qYd6tz9o8guF6r6qqomq2pyQkKC1zFNSfS4GkZ+DvkHnDsVLf2P34lMJeZp8XIvHf0RMFlVP3ZHbxGRhu70hsBWLzOYIGuc7LSD1e/s3Cty5jgoyPc7lamEvPy2UYDXgWWq+mzApOkcOql7ODDNqwzGI/ENYMQM6DEcvvs7vHs57MvyO5WpZLzc8+qNc4ehPiKS6j76A08C54nIKuBc97kJNVExMPAFGPCcc8/I1/rAthV+pzKViGcdTVX1e4o//7GvV+s15Sz5WkjoAFOugdf6wqX/gvYX+Z3KVALl8m2jCXPNejntYHVbw3vD4NsnoaDA71QmzFnxMsFRo5HzTWS3K+DbJ+D9qyDH7o5nvGPFywRPdBxcMh76PQkrv4B/nwvbV/udyoQpK14muETgtJvg6qnw6zanIX/lV36nMmHIipfxRsuznHawmk3hnSHw3bN2w1sTVHZZG+OdWs3guq9g+i0w82HnRh+XvAxV7H7FAKrKwg27mJaawc8bslAUQRD3O3oBEDn4lb2IM07cGeTguEMvKBznPHWWdeRzDi5Pjpj/0DgOjpcjph9aX+Gyj17/kfkOrS86Unjqsm4n/qEFsOJlvFWlKvzhdWjYDb55CLavgssnQ+0WfifzzfLMPUxLzeCTtAzS
d+0jJiqCU5rXJjrS+TNXnJ3Uwv1UdfdYC3dcFT00rEc8B7TAGYc7XgOXcdhynCeH1lPUsg+9NnCewOmFGfXgQg+97shlR0cG72DPipfxngj0HgP1O8GH18Jr5zg3/Wh1jt/Jys3Gnb8xPS2D6akZrNiSTWSE0Lt1XW4/ty3nd6pPfGy03xFDjmgItEMkJydrSkqK3zFMMOxYA+9dCdtXwHmPQK/RHHYsE0a2Ze/n00UZTE/LYOGGLACSm9ViYFIi/bs0pG71GH8DhgARWaCqyUVNsz0vU77qtILrv4b/3ARf/cVpBxv4gtPNIgzsycnlyyWZTE/L4L+rt1Og0L5BPPf0a8fFXRNpUruq3xHDhhUvU/5i4mHwm85J3bMfc/bChk6Gmk38TnZCcnLzmb18K9PTMpi5fCsH8gpoUjuOm85uxcBujWjXIN7viGHJipfxR0QEnHW3c8u1j0c5FzgcMgman+F3shLJyy9g7todTEvN4MslmWTvz6Nu9SoM69mUgUmJdG9S8+C3bsYbVryMv9pdCNfPdM6JfHMQXPAE9LyhQraDqSo/b8xiemoGMxZlsH3vAeJjorigcwMGJSXSq2UdooL4bZo5Nitexn8JbeGGmc4e2Od3Q2Ya9P87RMf6nQyAlVuymZa6ielpGWzcuY8qURH0bV+PQUmJnN2uHrHRkX5HrJSseJmKIbYGXP6uc1L3nKdg63IY+haclOhLnI07f+OTRU7XhuWZ2UQI9G5dlzF9na4NJ1nXBt9Z8TIVR0QE9PkLNOgCU2902sF+/wrUbonTfVtK/7MU827/9QCfL8nkk7RMUjZkoUD3JrV4eGAn+ndpSEK8dW2oSKyfl6mYtvzitIPt+p/fSQIUVfQiSlAYi3ttYJGNAIl0CrhEQkRkwDh3OCIy4PmxxkcUMV9x4wNef9iySrOM0oyPgianlPwTt35eJuTU7+ic2L3qayjIdc9p0cN/asHR46DoeQN+5ubns2brXpZl7GbN1mzyC5SacZF0aBhPhwbVSagec9xlHPPnCb2mwHkU5IPmuz8L3OGCgHEB0wryQQ8UMz5w/qJeX8z4om/mFTyRVeCB4NzK0IqXqbjiakLXwUFZVH6BMnfNDqanbeLzJZlk5+RRp1oVBiQ3ZGBSIj2a1rKuDXCokB5V6IorjKUsmEFkxcuELVUldWMW09MymLFoM9uy91M9JooLOjVgYFIivVtZ14ajiBw6zKvgrHiZsLN6azbTUjOYlprBhp2/USUygj7t6zEwKZE+7a1rQ7iw4mXCwqasfXyS5hSsZZv3ECFwequ63NKnNRd0akCNOOvaEG6seJmQtfPXA3y6eDPTUzcxf90uAJKa1OTBiztyUdeG1IuvGJ1cjTeseJmQsnd/Hl//ksm01Ay+X7WdvAKlTb3q3HV+Wy7ulkizOnaV1srCipep8Pbn5fN/K7YxPS2Db5ZtISe3gEY147j+zJYMSkqkfYN4+6awErLiZSqs9F2/8eKs1Xy2eDN7cvKoXa0Kg09uwiC3a0NEhBWsysyKl6mQFqVnce3EFH7dn8eFnd2uDa3rBvUa6Ca0WfEyFc7MZVu45Z2fqVO9Cu+N6k3renYxP3M0K16mQnlr7joenL6UTok1eH1Esn1jaIplxctUCAUFyt++WM6/5qzl3A71eOGK7lStYpunKZ5tHcZ3Obn53PlBGp8u2szVpzXjoYGdiLTGeHMcVryMr3b9eoAb3kwhZf0u/ty/PTec2dK6PZgSseJlfLN+x6+MnDCf9Kx9vDSsBxd1beh3JBNCrHgZX/y8YRfXT0ohX5V3rj+V5Oa1/Y5kQowVL1PuvlyayZj3fqZefCwTR55Cy4TqfkcyIciKlylXb3z/Px759Be6Na7J68OTqWO3vDcnyIqXKRf5Bcpjny7jjf/+jws61ef5od2Jq2LX1TInzoqX8VxObj63vZfKF0szGdm7Ofdf1NG6Qpgys+JlPLVj736ufzOF1I1Z/HVAR649o4XfkUyY
sOJlPLN2215GTpxP5u4cxl95Mv06N/A7kgkjnp2iLyJviMhWEVkSMK62iHwtIqvcn7W8Wr/x14L1O/nD+B/Izsnj3VGnWeEyQefl9UUmAv2OGDcWmKmqbYCZ7nMTZj5dtJkrXvuJmlWrMPXm0+nR1P5HmeDzrHip6hxg5xGjBwGT3OFJwCVerd+UP1XltTlrGf3OQro2qsFHN51ul2U2ninvNq/6qrrZHc4E6hc3o4iMAkYBNG3atByimbLIL1DGfbKUSXPXc1GXhvx9SDe7xZjxlG+XpVQtvNd5sdNfVdVkVU1OSEgox2SmtH47kMcf31rApLnrGfW7lvzziu5WuIznynvPa4uINFTVzSLSENhazus3QbYtez/XTZrPkk27eWRQJ67u1dzvSKaSKO89r+nAcHd4ODCtnNdvgmj11r38/uX/smrLXl69OtkKlylXnu15ici7wNlAXRFJBx4EngSmiMh1wHpgiFfrN976ae0ORr21gOhI4f0/nkbXxjX9jmQqGc+Kl6peUcykvl6t05SPaambuPuDRTSpHcfEkT1pUruq35FMJWQ97E2JqSrj/28NT32xgp4tavPa1cnUqBrtdyxTSVnxMiWSl1/AA9OW8u68DQzslsjTg7sSE2XfKBr/WPEyx/Xr/jxueWchs1ds4+azW3HX+e3sbtXGd1a8zDFt2ZPDtRPnszwzm8d/34Vhp1qHYVMxWPEyxVq5JZsRb8wja18u/x6ezDnt6vkdyZiDrHiZIv2wejt/fHsBcdGRTPljLzo3quF3JGMOY8XLHOXjhenc+9EiWtStxoSRPWlUM87vSMYcxYqXOUhV+ees1Tz79UpOb1WH8VedTI046wphKiYrXgaA3PwC/jJ1MVNS0rm0RyOevLQrVaJ8O2/fmOOy4mXIzsnl5skL+W7Vdv7Utw23n9sGEesKYSo2K16V3Obd+xg5YT6rt+7lqT90ZcgpTfyOZEyJWPGqxJZt3sPICfPZuz+PCSNP4cw2dt00EzqseFVSc1Zu4+bJC6keE8UHN/aiQ8OT/I5kTKlY8aqEpqRs5M8fL6Z1vepMGHkKDWtYVwgTeqx4VSKqynPfrOKFmas4s01dXr6yB/Gx1hXChCYrXpXEgbwCxn68iI8XbmJIcmMe+30XoiOtK4QJXVa8KoHd+3K56e0F/LBmB3ee15Zb+rS2rhAm5FnxCnObsvYxcsI8/rf9V54d0o1LezT2O5IxQWHFK4wt2bSbayfOZ19uPpNG9uT01nX9jmRM0FjxClOzV2xl9OSF1KpahbevP5W29eP9jmRMUFnxCkPv/LSBB6YtoX2DeCaMOIV6J8X6HcmYoLPiFUYKCpRnvlrBy9+u4Zx2Cbw4rAfVYuxXbMKTbdlhYn9ePnd/sIjpaRkMO7Up4wZ2Isq6QpgwZsUrDGT9doBRby1g3v92cm+/9tx4VkvrCmHCnhWvELdx52+MmDCPjTv38Y/LkxiU1MjvSMaUCyteIWxRehbXTpxPbr7y1nU9ObVlHb8jGVNuwqp4fbk0k3s/WkRUhBAVEUFkhBAdKe5P53lUhBAVGTgtgugIZ56oSOd1Ue5wZETEwdcXvq5w2c68RS876uCy3BwBw1FFDB+2jMgipkVEHHWfxG9+2cKt7/5MnepVeG9UT1rXq+7Tp26MP8KqeDWqGcegbonkFij5+UpuQQH5BUpevpLnDufmq/uzgP25BeQW5JNfUODOc2jaoXkLyHOXkV/gLFO1/N+bCERHHCpwe/fn0bVRDf49/BQS4mPKP5AxPgur4tW5UY1yuUVXQUFAYQwojnn5hxe/w6YdMZx/1HglL7/gqAJ61OvcaTXiohn1u5ZUrRJWv0JjSsy2/BMQESHERET6HcOYSs06AhljQpIVL2NMSLLiZYwJSVa8jDEhyYqXMSYkWfEyxoQkK17GmJBkxcsYE5JE/TjXpZREZBuwvoSz1wW2exinorD3GV7sfRatmaomFDUhJIpXaYhIiqom+53Da/Y+
w4u9z9Kzw0ZjTEiy4mWMCUnhWLxe9TtAObH3GV7sfZZS2LV5GWMqh3Dc8zLGVAJWvIwxISmsipeI9BORFSKyWkTG+p3HCyLyhohsFZElfmfxkog0EZHZIvKLiCwVkTF+Z/KCiMSKyDwRSXPf58N+Z/KSiESKyM8iMqOsywqb4iUikcBLwIVAR+AKEenobypPTAT6+R2iHOQBd6pqR+A0YHSY/j73A31UtRuQBPQTkdP8jeSpMcCyYCwobIoX0BNYraprVfUA8B4wyOdMQaeqc4CdfufwmqpuVtWF7nA2zgYfdjelVMde92m0+wjLb9FEpDFwEfDvYCwvnIpXI2BjwPN0wnBjr4xEpDnQHfjJ5yiecA+lUoGtwNeqGpbvE3geuAcoCMbCwql4mTAkItWBj4DbVHWP33m8oKr5qpoENAZ6ikhnnyMFnYgMALaq6oJgLTOcitcmoEnA88buOBOiRCQap3BNVtWP/c7jNVXNAmYTnm2avYGBIrIOp0mnj4i8XZYFhlPxmg+0EZEWIlIFuByY7nMmc4JERIDXgWWq+qzfebwiIgkiUtMdjgPOA5b7GsoDqnqfqjZW1eY4f5uzVPWqsiwzbIqXquYBtwBf4jTuTlHVpf6mCj4ReReYC7QTkXQRuc7vTB7pDVyN8x861X309zuUBxoCs0VkEc4/4K9VtczdCCoDOz3IGBOSwmbPyxhTuVjxMsaEJCtexpiQZMXLGBOSrHgZY0KSFS/jKRHJD+jqkBrMq32ISPNwv7qGKV6U3wFM2NvnnvpiTFDZnpfxhYisE5GnRGSxez2r1u745iIyS0QWichMEWnqjq8vIlPd616licjp7qIiReQ191pYX7m91BGRP7nXAlskIu/59DaNh6x4Ga/FHXHYODRg2m5V7QK8iHPFAYB/ApNUtSswGXjBHf8C8H/uda96AIVnT7QBXlLVTkAW8Ad3/Figu7ucG715a8ZP1sPeeEpE9qpq9SLGr8O5CN9a9wTsTFWtIyLbgYaqmuuO36yqdd27pjdW1f0By2iOczpNG/f5vUC0qj4qIl8Ae4H/AP8JuGaWCRO252X8pMUMl8b+gOF8DrXjXoRzZd0ewHwRsfbdMGPFy/hpaMDPue7wDzhXHQC4EvjOHZ4J3AQHL95Xo7iFikgE0ERVZwP3AjWAo/b+TGiz/0bGa3HuVUILfaGqhd0larlXU9gPXOGOuxWYICJ3A9uAke74McCr7lU08nEK2eZi1hkJvO0WOAFecK+VZcKItXkZX7htXsmqut3vLCY02WGjMSYk2Z6XMSYk2Z6XMSYkWfEyxoQkK17GmJBkxcsYE5KseBljQtL/A/w8USoPr020AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 576x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "training_vis(train_losses, valid_losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 模型评估\n",
    "\n",
    "在测试集上评估模型效果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T04:55:47.340416Z",
     "iopub.status.busy": "2021-11-29T04:55:47.339537Z",
     "iopub.status.idle": "2021-11-29T04:55:47.606447Z",
     "shell.execute_reply": "2021-11-29T04:55:47.607038Z",
     "shell.execute_reply.started": "2021-11-28T14:01:44.127754Z"
    },
    "papermill": {
     "duration": 2.453872,
     "end_time": "2021-11-29T04:55:47.607210",
     "exception": false,
     "start_time": "2021-11-29T04:55:45.153338",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 加载得分最高的模型\n",
    "checkpoint = torch.load('../input/ai-earth-model-weights/task05_model_weights.pth')\n",
    "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size)\n",
    "model.load_state_dict(checkpoint['state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T04:55:51.849996Z",
     "iopub.status.busy": "2021-11-29T04:55:51.849073Z",
     "iopub.status.idle": "2021-11-29T04:55:51.851492Z",
     "shell.execute_reply": "2021-11-29T04:55:51.850969Z",
     "shell.execute_reply.started": "2021-11-28T14:06:59.931318Z"
    },
    "papermill": {
     "duration": 2.125413,
     "end_time": "2021-11-29T04:55:51.851629",
     "exception": false,
     "start_time": "2021-11-29T04:55:49.726216",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 测试集路径\n",
    "test_path = '../input/ai-earth-tests/'\n",
    "# 测试集标签路径\n",
    "test_label_path = '../input/ai-earth-tests-labels/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T04:55:56.429364Z",
     "iopub.status.busy": "2021-11-29T04:55:56.428800Z",
     "iopub.status.idle": "2021-11-29T04:55:58.115486Z",
     "shell.execute_reply": "2021-11-29T04:55:58.115007Z",
     "shell.execute_reply.started": "2021-11-28T14:07:13.415385Z"
    },
    "papermill": {
     "duration": 4.135325,
     "end_time": "2021-11-29T04:55:58.115667",
     "exception": false,
     "start_time": "2021-11-29T04:55:53.980342",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# 读取测试数据和测试数据的标签\n",
    "files = os.listdir(test_path)\n",
    "X_test = []\n",
    "y_test = []\n",
    "for file in files:\n",
    "    X_test.append(np.load(test_path + file))\n",
    "    y_test.append(np.load(test_label_path + file))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T04:56:02.560786Z",
     "iopub.status.busy": "2021-11-29T04:56:02.559461Z",
     "iopub.status.idle": "2021-11-29T04:56:02.587431Z",
     "shell.execute_reply": "2021-11-29T04:56:02.588024Z",
     "shell.execute_reply.started": "2021-11-28T14:07:17.046359Z"
    },
    "papermill": {
     "duration": 2.329175,
     "end_time": "2021-11-29T04:56:02.588201",
     "exception": false,
     "start_time": "2021-11-29T04:56:00.259026",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((103, 12, 24, 48, 1), (103, 24))"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_test = np.array(X_test)[:, :, :, 19: 67, :1]\n",
    "y_test = np.array(y_test)\n",
    "X_test.shape, y_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-11-29T04:56:07.675344Z",
     "iopub.status.busy": "2021-11-29T04:56:07.674488Z",
     "iopub.status.idle": "2021-11-29T04:56:07.682481Z",
     "shell.execute_reply": "2021-11-29T04:56:07.682895Z",
     "shell.execute_reply.started": "2021-11-28T14:07:31.503452Z"
    },
    "papermill": {
     "duration": 2.455352,
     "end_time": "2021-11-29T04:56:07.683041",
     "exception": false,
     "start_time": "2021-11-29T04:56:05.227689",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "testset = AIEarthDataset(X_test, y_test)\n",
    "testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 在测试集上评估模型效果\n",
    "model.eval()\n",
    "model.to(device)\n",
    "preds = np.zeros((len(y_test),24))\n",
    "for i, data in tqdm(enumerate(testloader)):\n",
    "    data, labels = data\n",
    "    data = data.to(device)\n",
    "    labels = labels.to(device)\n",
    "    pred = model(data, train=False)\n",
    "    preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()\n",
    "s = score(y_test, preds)\n",
    "print('Score: {:.3f}'.format(s))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "papermill": {
     "duration": null,
     "end_time": null,
     "exception": null,
     "start_time": null,
     "status": "pending"
    },
    "tags": []
   },
   "source": [
    "## 总结\n",
    "\n",
    "这一次的TOP方案没有自己设计模型，而是使用了目前时空序列预测领域现有的模型，另一组TOP选手“ailab”也使用了现有的模型PredRNN++，关于时空序列预测领域的一些比较经典的模型可以参考https://www.zhihu.com/column/c_1208033701705162752"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 作业\n",
    "\n",
    "该TOP方案中以sst作为预测目标，间接计算nino3.4指数，学有余力的同学可以尝试用SA-ConvLSTM模型直接预测nino3.4指数。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 参考文献\n",
    "\n",
    "1. 吴先生的队伍方案分享：https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.9.561d5330dF9lX1&postId=231465\n",
    "2. ailab团队思路分享：https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.15.561d5330dF9lX1&postId=210734"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 6708.571081,
   "end_time": "2021-11-29T04:56:15.789285",
   "environment_variables": {},
   "exception": true,
   "input_path": "__notebook__.ipynb",
   "output_path": "__notebook__.ipynb",
   "parameters": {},
   "start_time": "2021-11-29T03:04:27.218204",
   "version": "2.3.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
